{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "\n",
    "import os\n",
    "import random\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 1234\n",
    "\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_transforms = transforms.Compose([\n",
    "                           transforms.RandomHorizontalFlip(),\n",
    "                           transforms.RandomRotation(10),\n",
    "                           transforms.RandomCrop((224, 224), pad_if_needed=True),\n",
    "                           transforms.ToTensor(),\n",
    "                           transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))\n",
    "                       ])\n",
    "\n",
    "test_transforms = transforms.Compose([\n",
    "                           transforms.CenterCrop((224, 224)),\n",
    "                           transforms.ToTensor(),\n",
    "                           transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))\n",
    "                       ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = datasets.ImageFolder('data/dogs-vs-cats/train', train_transforms)\n",
    "valid_data = datasets.ImageFolder('data/dogs-vs-cats/valid', test_transforms)\n",
    "test_data = datasets.ImageFolder('data/dogs-vs-cats/test', test_transforms)\n",
    "\n",
    "#import os\n",
    "\n",
    "#print(len(os.listdir('data/dogs-vs-cats/train')))\n",
    "\n",
    "#n_train_examples = int(len(train_data)*0.9)\n",
    "#n_valid_examples = n_test_examples = len(train_data) - n_train_examples\n",
    "\n",
    "#train_data, valid_data = torch.utils.data.random_split(train_data, [n_train_examples, n_valid_examples])\n",
    "#train_data, test_data = torch.utils.data.random_split(train_data, [n_train_examples-n_valid_examples, n_test_examples])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "https://github.com/facebook/fb.resnet.torch/issues/180\n",
    "https://github.com/bamos/densenet.pytorch/blob/master/compute-cifar10-mean.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of training examples: 20000\n",
      "Number of validation examples: 2500\n",
      "Number of testing examples: 2500\n"
     ]
    }
   ],
   "source": [
    "print(f'Number of training examples: {len(train_data)}')\n",
    "print(f'Number of validation examples: {len(valid_data)}')\n",
    "print(f'Number of testing examples: {len(test_data)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 64\n",
    "\n",
    "train_iterator = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)\n",
    "valid_iterator = torch.utils.data.DataLoader(valid_data, batch_size=BATCH_SIZE)\n",
    "test_iterator = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "https://discuss.pytorch.org/t/why-does-the-resnet-model-given-by-pytorch-omit-biases-from-the-convolutional-layer/10990/4\n",
    "https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.models as models\n",
    "\n",
    "model = models.resnet18(pretrained=True).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ResNet(\n",
      "  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU(inplace)\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  (layer1): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace)\n",
      "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace)\n",
      "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace)\n",
      "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace)\n",
      "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace)\n",
      "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace)\n",
      "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)\n",
      "  (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "for param in model.parameters():\n",
    "    param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear(in_features=512, out_features=1000, bias=True)\n"
     ]
    }
   ],
   "source": [
    "print(model.fc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fc = nn.Linear(in_features=512, out_features=2).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_accuracy(fx, y):\n",
    "    preds = fx.max(1, keepdim=True)[1]\n",
    "    correct = preds.eq(y.view_as(preds)).sum()\n",
    "    acc = correct.float()/preds.shape[0]\n",
    "    return acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, device, iterator, optimizer, criterion):\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    \n",
    "    model.train()\n",
    "    \n",
    "    for (x, y) in iterator:\n",
    "        \n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "                \n",
    "        fx = model(x)\n",
    "        \n",
    "        loss = criterion(fx, y)\n",
    "        \n",
    "        acc = calculate_accuracy(fx, y)\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        epoch_loss += loss.item()\n",
    "        epoch_acc += acc.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, device, iterator, criterion):\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    \n",
    "    model.eval()\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for (x, y) in iterator:\n",
    "\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "\n",
    "            fx = model(x)\n",
    "\n",
    "            loss = criterion(fx, y)\n",
    "\n",
    "            acc = calculate_accuracy(fx, y)\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "            epoch_acc += acc.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Epoch: 01 | Train Loss: 0.198 | Train Acc: 91.95% | Val. Loss: 0.089 | Val. Acc: 96.80% |\n",
      "| Epoch: 02 | Train Loss: 0.136 | Train Acc: 94.40% | Val. Loss: 0.069 | Val. Acc: 97.50% |\n",
      "| Epoch: 03 | Train Loss: 0.128 | Train Acc: 94.68% | Val. Loss: 0.059 | Val. Acc: 97.70% |\n",
      "| Epoch: 04 | Train Loss: 0.119 | Train Acc: 95.03% | Val. Loss: 0.070 | Val. Acc: 97.30% |\n",
      "| Epoch: 05 | Train Loss: 0.118 | Train Acc: 94.95% | Val. Loss: 0.057 | Val. Acc: 97.73% |\n",
      "| Epoch: 06 | Train Loss: 0.121 | Train Acc: 94.95% | Val. Loss: 0.056 | Val. Acc: 97.70% |\n",
      "| Epoch: 07 | Train Loss: 0.117 | Train Acc: 95.11% | Val. Loss: 0.063 | Val. Acc: 97.46% |\n",
      "| Epoch: 08 | Train Loss: 0.110 | Train Acc: 95.44% | Val. Loss: 0.052 | Val. Acc: 97.93% |\n",
      "| Epoch: 09 | Train Loss: 0.116 | Train Acc: 95.14% | Val. Loss: 0.056 | Val. Acc: 97.77% |\n",
      "| Epoch: 10 | Train Loss: 0.114 | Train Acc: 95.36% | Val. Loss: 0.063 | Val. Acc: 97.46% |\n"
     ]
    }
   ],
   "source": [
    "EPOCHS = 10\n",
    "SAVE_DIR = 'models'\n",
    "MODEL_SAVE_PATH = os.path.join(SAVE_DIR, 'resnet18-dogs-vs-cats.pt')\n",
    "\n",
    "best_valid_loss = float('inf')\n",
    "\n",
    "if not os.path.isdir(f'{SAVE_DIR}'):\n",
    "    os.makedirs(f'{SAVE_DIR}')\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    train_loss, train_acc = train(model, device, train_iterator, optimizer, criterion)\n",
    "    valid_loss, valid_acc = evaluate(model, device, valid_iterator, criterion)\n",
    "    \n",
    "    if valid_loss < best_valid_loss:\n",
    "        best_valid_loss = valid_loss\n",
    "        torch.save(model.state_dict(), MODEL_SAVE_PATH)\n",
    "    \n",
    "    print(f'| Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc*100:05.2f}% | Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:05.2f}% |')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Test Loss: 0.052 | Test Acc: 97.93% |\n"
     ]
    }
   ],
   "source": [
    "model.load_state_dict(torch.load(MODEL_SAVE_PATH))\n",
    "\n",
    "test_loss, test_acc = evaluate(model, device, valid_iterator, criterion)\n",
    "\n",
    "print(f'| Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:05.2f}% |')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
